{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: Import Packages\n",
    "Add the repository root to the system path and import the required packages and functions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import os\n",
    "import os.path as osp\n",
    "import sys\n",
    "import warnings\n",
    "from pathlib import Path\n",
    "\n",
    "# Suppress all Python warnings for the rest of the session.\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "# Notebooks do not define __file__; the repository root is taken to be the parent of the\n",
    "# kernel's current working directory.\n",
    "ROOT = str(Path.cwd().resolve().parent)\n",
    "# Guarded so that re-running this cell does not append duplicate sys.path entries.\n",
    "if ROOT not in sys.path:\n",
    "    sys.path.append(ROOT)\n",
    "\n",
    "import torch\n",
    "from mmcv import Config\n",
    "# trademaster is expected to be importable from ROOT, so these imports stay below the\n",
    "# sys.path update.\n",
    "from trademaster.utils import replace_cfg_vals\n",
    "from trademaster.nets.builder import build_net\n",
    "from trademaster.environments.builder import build_environment\n",
    "from trademaster.datasets.builder import build_dataset\n",
    "from trademaster.agents.builder import build_agent\n",
    "from trademaster.optimizers.builder import build_optimizer\n",
    "from trademaster.losses.builder import build_loss\n",
    "from trademaster.trainers.builder import build_trainer\n",
    "from trademaster.transition.builder import build_transition"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Load Configs\n",
    "Load the default config from the file `configs/high_frequency_trading/high_frequency_trading_BTC_dqn_dqn_adam_mse.py`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Default experiment config, located relative to the repository root.\n",
    "DEFAULT_CONFIG = osp.join(\n",
    "    ROOT, \"configs\", \"high_frequency_trading\",\n",
    "    \"high_frequency_trading_BTC_dqn_dqn_adam_mse.py\",\n",
    ")\n",
    "\n",
    "parser = argparse.ArgumentParser(\n",
    "    description=\"Train or test a DQN agent for high-frequency trading\"\n",
    ")\n",
    "parser.add_argument(\"--config\", default=DEFAULT_CONFIG, help=\"path to the experiment config file\")\n",
    "parser.add_argument(\"--task_name\", type=str, default=\"train\")\n",
    "parser.add_argument(\"--test_style\", type=str, default='-1')\n",
    "# An explicit empty argv makes argparse use the defaults above instead of reading\n",
    "# the kernel process's own command line.\n",
    "args = parser.parse_args([])\n",
    "task_name = args.task_name\n",
    "\n",
    "cfg = Config.fromfile(args.config)\n",
    "# The config returned by replace_cfg_vals is the one used from here on\n",
    "# (presumably with placeholder values substituted; see trademaster.utils).\n",
    "cfg = replace_cfg_vals(cfg)\n",
    "# update test style\n",
    "cfg.data.update({'test_style': args.test_style})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Config (path: /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/configs/high_frequency_trading/high_frequency_trading_BTC_dqn_dqn_adam_mse.py): {'data': {'type': 'HighFrequencyTradingDataset', 'data_path': 'data/high_frequency_trading/small_BTC', 'train_path': 'data/high_frequency_trading/small_BTC/train.csv', 'valid_path': 'data/high_frequency_trading/small_BTC/valid.csv', 'test_path': 'data/high_frequency_trading/small_BTC/test.csv', 'test_style_path': 'data/high_frequency_trading/small_BTC/test.csv', 'tech_indicator_list': ['imblance_volume_oe', 'sell_spread_oe', 'buy_spread_oe', 'kmid2', 'bid1_size_n', 'ksft2', 'ma_10', 'ksft', 'kmid', 'ask1_size_n', 'trade_diff', 'qtlu_10', 'qtld_10', 'cntd_10', 'beta_10', 'roc_10', 'bid5_size_n', 'rsv_10', 'imxd_10', 'ask5_size_n', 'ma_30', 'max_10', 'qtlu_30', 'imax_10', 'imin_10', 'min_10', 'qtld_30', 'cntn_10', 'rsv_30', 'cntp_10', 'ma_60', 'max_30', 'qtlu_60', 'qtld_60', 'cntd_30', 'roc_30', 'beta_30', 'bid4_size_n', 'rsv_60', 'ask4_size_n', 'imxd_30', 'min_30', 'max_60', 'imax_30', 'imin_30', 'cntd_60', 'roc_60', 'beta_60', 'cntn_30', 'min_60', 'cntp_30', 'bid3_size_n', 'imxd_60', 'ask3_size_n', 'sell_volume_oe', 'imax_60', 'imin_60', 'cntn_60', 'buy_volume_oe', 'cntp_60', 'bid2_size_n', 'kup', 'bid1_size', 'ask1_size', 'std_30', 'ask2_size_n'], 'transcation_cost': 0, 'backward_num_timestamp': 1, 'max_holding_number': 0.01, 'num_action': 11, 'max_punish': 1000000000000.0, 'episode_length': 14400, 'test_style': '-1'}, 'environment': {'type': 'HighFrequencyTradingEnvironment'}, 'agent': {'type': 'HighFrequencyTradingDDQN', 'auxiliary_coffient': 512, 'reward_scale': 1, 'repeat_times': 1, 'gamma': 0.99, 'batch_size': 64, 'clip_grad_norm': 3.0, 'soft_update_tau': 0, 'state_value_tau': 0.005}, 'trainer': {'type': 'HighFrequencyTradingTrainer', 'epochs': 10, 'work_dir': 'work_dir/high_frequency_trading_BTC_high_frequency_trading_dqn_ddqn_adam_mse', 'seeds': 12345, 'batch_size': 512, 'horizon_len': 512, 'buffer_size': 
100000.0, 'num_threads': 8, 'if_remove': False, 'if_discrete': True, 'if_off_policy': True, 'if_keep_save': True, 'if_over_write': False, 'if_save_buffer': False}, 'loss': {'type': 'HFTLoss', 'ada': 1}, 'optimizer': {'type': 'Adam', 'lr': 0.001}, 'act': {'type': 'HFTQNet', 'state_dim': 66, 'action_dim': 11, 'dims': 16, 'explore_rate': 0.25, 'max_punish': 0}, 'cri': None, 'task_name': 'high_frequency_trading', 'dataset_name': 'BTC', 'optimizer_name': 'adam', 'loss_name': 'mse', 'auxiliry_loss_name': 'KLdiv', 'net_name': 'high_frequency_trading_dqn', 'agent_name': 'ddqn', 'work_dir': 'work_dir/high_frequency_trading_BTC_high_frequency_trading_dqn_ddqn_adam_mse', 'batch_size': 512, 'train_environment': {'type': 'HighFrequencyTradingTrainingEnvironment'}}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cfg"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3: Build Dataset\n",
    "Build datasets from cfg defined above"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = build_dataset(cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<trademaster.datasets.high_frequency_trading.dataset.HighFrequencyTradingDataset at 0x7f1fccd2a9d0>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4: Build Reinforcement Learning Environments\n",
    "Build environments based on cfg and previously-defined dataset\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "valid\n",
      "test\n"
     ]
    }
   ],
   "source": [
    "def _build_env_for(task):\n",
    "    \"\"\"Build an environment for `task` from the current `cfg` and the shared `dataset`.\"\"\"\n",
    "    return build_environment(cfg, default_args={\"dataset\": dataset, \"task\": task})\n",
    "\n",
    "\n",
    "# The evaluation environments are built first, before cfg.environment is replaced.\n",
    "valid_environment = _build_env_for(\"valid\")\n",
    "test_environment = _build_env_for(\"test\")\n",
    "\n",
    "# NOTE: this overwrites cfg.environment in place. Re-running this cell without reloading\n",
    "# cfg would presumably build valid/test from the training environment settings as well.\n",
    "cfg.environment = cfg.train_environment\n",
    "train_environment = _build_env_for(\"train\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<trademaster.environments.high_frequency_trading.environment.HighFrequencyTradingTrainingEnvironment at 0x7f1fccd47150>"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<trademaster.environments.high_frequency_trading.environment.HighFrequencyTradingEnvironment at 0x7f1fccd30990>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "valid_environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<trademaster.environments.high_frequency_trading.environment.HighFrequencyTradingEnvironment at 0x7f1fccd5f8d0>"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_environment"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 5: Build Net \n",
    "Update the state and action dimensions in the config, then create the nets and optimizers for DQN\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take the dimensions from the training environment and write them into the net configs.\n",
    "action_dim = train_environment.action_dim\n",
    "state_dim = train_environment.state_dim\n",
    "net_dims = {\"action_dim\": action_dim, \"state_dim\": state_dim}\n",
    "\n",
    "cfg.act.update(net_dims)\n",
    "act = build_net(cfg.act)\n",
    "act_optimizer = build_optimizer(cfg, default_args={\"params\": act.parameters()})\n",
    "\n",
    "# The critic is optional: when cfg.cri is falsy (e.g. None) no critic net or optimizer is built.\n",
    "if not cfg.cri:\n",
    "    cri = None\n",
    "    cri_optimizer = None\n",
    "else:\n",
    "    cfg.cri.update(net_dims)\n",
    "    cri = build_net(cfg.cri)\n",
    "    cri_optimizer = build_optimizer(cfg, default_args={\"params\": cri.parameters()})"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 6: Build Loss\n",
    "Build loss from config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = build_loss(cfg)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 7: Build Agent\n",
    "Build agent from config and detect device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the GPU when CUDA is available, otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Objects created in the earlier steps, handed to the agent builder via default_args.\n",
    "agent_components = {\n",
    "    \"action_dim\": action_dim,\n",
    "    \"state_dim\": state_dim,\n",
    "    \"act\": act,\n",
    "    \"cri\": cri,\n",
    "    \"act_optimizer\": act_optimizer,\n",
    "    \"cri_optimizer\": cri_optimizer,\n",
    "    \"criterion\": criterion,\n",
    "    \"device\": device,\n",
    "}\n",
    "agent = build_agent(cfg, default_args=agent_components)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 8: Build Trainer\n",
    "Build the trainer from the config and create the work directory used to save the results, model, and config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "| Arguments Keep work_dir: /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/high_frequency_trading_BTC_high_frequency_trading_dqn_ddqn_adam_mse\n"
     ]
    }
   ],
   "source": [
    "# Environments, agent and device are handed to the trainer builder via default_args.\n",
    "trainer_components = {\n",
    "    \"train_environment\": train_environment,\n",
    "    \"valid_environment\": valid_environment,\n",
    "    \"test_environment\": test_environment,\n",
    "    \"agent\": agent,\n",
    "    \"device\": device,\n",
    "}\n",
    "trainer = build_trainer(cfg, default_args=trainer_components)\n",
    "\n",
    "# Save the config in use under the work directory, keeping the original file name.\n",
    "config_copy_path = osp.join(ROOT, cfg.work_dir, osp.basename(args.config))\n",
    "cfg.dump(config_copy_path)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 9: Train the Trainer\n",
    "Train the trainer based on the config and get results from workdir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the holding could not be bought all clear\n",
      "the holding could not be sell all clear\n",
      "the holding could not be sell all clear\n",
      "Train Episode: [1/10]\n",
      "the holding could not be sell all clear\n",
      "the holding could not be sell all clear\n",
      "the holding could not be sell all clear\n",
      "the holding could not be sell all clear\n",
      "the holding could not be sell all clear\n",
      "the holding could not be sell all clear\n",
      "the holding could not be sell all clear\n",
      "the holding could not be bought all clear\n",
      "the holding could not be bought all clear\n",
      "the holding could not be bought all clear\n",
      "the holding could not be bought all clear\n",
      "the holding could not be bought all clear\n",
      "the holding could not be sell all clear\n",
      "the holding could not be bought all clear\n",
      "the holding could not be bought all clear\n",
      "the holding could not be sell all clear\n",
      "the holding could not be bought all clear\n",
      "the holding could not be bought all clear\n",
      "the holding could not be sell all clear\n",
      "the holding could not be sell all clear\n",
      "the holding could not be sell all clear\n",
      "the holding could not be bought all clear\n",
      "the holding could not be bought all clear\n",
      "the holding could not be bought all clear\n",
      "the holding could not be bought all clear\n",
      "the holding could not be bought all clear\n",
      "the holding could not be sell all clear\n",
      "the holding could not be sell all clear\n",
      "the holding could not be sell all clear\n",
      "the holding could not be sell all clear\n",
      "the holding could not be sell all clear\n",
      "the holding could not be sell all clear\n",
      "the holding could not be bought all clear\n",
      "the holding could not be sell all clear\n",
      "the holding could not be bought all clear\n",
      "the holding could not be bought all clear\n",
      "the holding could not be sell all clear\n",
      "the holding could not be bought all clear\n",
      "the holding could not be bought all clear\n",
      "the holding could not be bought all clear\n",
      "the holding could not be bought all clear\n",
      "the holding could not be bought all clear\n",
      "the holding could not be bought all clear\n",
      "the holding could not be bought all clear\n",
      "the holding could not be bought all clear\n",
      "the holding could not be sell all clear\n",
      "the holding could not be bought all clear\n",
      "the holding could not be sell all clear\n",
      "the holding could not be bought all clear\n",
      "the holding could not be bought all clear\n",
      "the holding could not be bought all clear\n",
      "the holding could not be bought all clear\n"
     ]
    }
   ],
   "source": [
    "# A \"train*\" task runs train_and_valid() followed by test(); a \"test*\" task runs only test().\n",
    "# Any other task_name runs nothing.\n",
    "is_train_task = task_name.startswith(\"train\")\n",
    "is_test_task = task_name.startswith(\"test\")\n",
    "\n",
    "if is_train_task:\n",
    "    trainer.train_and_valid()\n",
    "if is_train_task or is_test_task:\n",
    "    trainer.test()\n",
    "    print(\"train end\" if is_train_task else \"test end\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "HFT",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.15"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "c33605b009166d65f90ad63d824c8e63d22d0973c031452c4b4158e2872c99ad"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
